import paddle

from src.model.ssrnet import SSRNet
from src.model.sshead import SSHead
from src.model.ssloss import SSLoss
from src.model.sspred import SSPred

class SSYOLO(paddle.nn.Layer):
    def __init__(self, num_classes=20):
        """
        初始化检测网络
        params:
        - num_classes: 物体类别数量
        """
        super().__init__()
        
        # 骨干网络
        self.backbone = SSRNet(group_arch=[[3, 64, 256, 3, 8, 2], [128, 64, 256, 1, 4, 2], [128, 64, 256, 1, 2, 2]], block_mode='ssr')

        # 检测头部
        self.det_head = SSHead(num_classes=num_classes, in_channels=[256, 256, 256])

        # 计算损失
        self.get_loss = SSLoss(num_classes=num_classes, out_strides=[32, 16, 8])

        # 计算预测
        self.get_pred = SSPred(num_classes=num_classes)
        
    def forward(self, images):
        """
        提取卷积特征
        params:
        - images: 输入图像
        return:
        - p_list: 预测特征
        """
        c_list = self.backbone(images) # 提取骨干网络特征
        p_list = self.det_head(c_list) # 提取检测头部特征
        
        return p_list